{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reference\n",
    "* 李宏毅老師投影片\n",
    "* Pytorch document"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attention 說明"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./image/seq2seq.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./image/atten1.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./image/attention.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pytorch Flow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./image/atten_flow.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AttnDecoderRNN(nn.Module):\n",
    "    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=5):\n",
    "        super(AttnDecoderRNN, self).__init__()\n",
    "        \n",
    "        \n",
    "        self.hidden_size = hidden_size\n",
    "        self.output_size = output_size\n",
    "        self.dropout_p = dropout_p\n",
    "        self.max_length = max_length\n",
    "\n",
    "        self.embedding = nn.Embedding(self.output_size, self.hidden_size)\n",
    "        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)\n",
    "        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)\n",
    "        self.dropout = nn.Dropout(self.dropout_p)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n",
    "        self.out = nn.Linear(self.hidden_size, self.output_size)\n",
    "\n",
    "    def forward(self, input, hidden, encoder_outputs):\n",
    "        embedded = self.embedding(input).view(1, 1, -1)\n",
    "        embedded = self.dropout(embedded)\n",
    "        \n",
    "        \n",
    "        \n",
    "        #########################################\n",
    "        # 第一步： calculate weights !\n",
    "        # match fun. : Applies a linear transformation to the incoming data: y = Ax + b\n",
    "        # self.attn = nn.Linear(self.hidden_size * 2, self.max_length)\n",
    "        # using the decoder’s input and hidden state as inputs\n",
    "        o = torch.cat((embedded[0], hidden[0]), 1)\n",
    "        o = self.attn(o)\n",
    "        attn_weights = F.softmax(o, dim=1)\n",
    "                      \n",
    "                      \n",
    "                      \n",
    "        \n",
    "        # attn_weights = F.softmax(\n",
    "        #                 self.attn(torch.cat((embedded[0], hidden[0]), 1)), \n",
    "        #                          dim=1)\n",
    "        \n",
    "        \n",
    "        \n",
    "        ##########################################\n",
    "        # 第二步：encoder_outputs 和  softmax運算後的權重相乘。\n",
    "        # torch.bmm: Performs a batch matrix-matrix product of matrices.\n",
    "        attn_applied = torch.bmm(attn_weights.unsqueeze(0),\n",
    "                                 encoder_outputs.unsqueeze(0))\n",
    "        \n",
    "        \n",
    "\n",
    "        # 合成 Decoder input \n",
    "        output = torch.cat((embedded[0], attn_applied[0]), 1)\n",
    "        output = self.attn_combine(output).unsqueeze(0)\n",
    "        \n",
    "        \n",
    "        #############################################\n",
    "        # Training RNN\n",
    "        output = F.relu(output)\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "\n",
    "        output = F.log_softmax(self.out(output[0]), dim=1)\n",
    "        return output, hidden, attn_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 舉例：image caption\n",
    "![](./image/image_caption.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
